import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

from core.models.modules.rope import RotaryPositionalEmbeddings


class CapTalkGen(nn.Module):
    """Multi-level token transformer conditioned on audio and caption features.

    Tokens are grouped into levels: level ``i`` holds ``patch_nums[i]`` tokens,
    ``L = sum(patch_nums)`` tokens in total. ``forward`` concatenates a
    "previous" sequence with the current one, runs them jointly through
    ``depth`` ``SelfCrossAttn`` blocks, and returns only the outputs for the
    current tokens.

    Args:
        embed_dim: Token feature width.
        audio_dim: Audio context feature width.
        caption_dim: Width of the video and audio caption features.
        num_heads: Attention heads per block.
        depth: Number of ``SelfCrossAttn`` blocks.
        patch_nums: Token count of each level.
    """

    def __init__(self, embed_dim, audio_dim, caption_dim, num_heads, depth, patch_nums):
        super().__init__()
        self.attn_depth = depth
        self.patch_nums = patch_nums
        # Per-block drop-path rate, rising linearly from 0 to 0.1 * depth / 24.
        drop_rate = [x.item() for x in torch.linspace(0, 0.1 * depth / 24, depth)]
        self.attn_blocks = nn.ModuleList(
            [
                SelfCrossAttn(
                    embed_dim=embed_dim,
                    audio_dim=audio_dim,
                    caption_dim=caption_dim,
                    num_heads=num_heads,
                    drop_path=rate,
                )
                for rate in drop_rate
            ]
        )
        # One learned embedding per level, added to every token of that level.
        self.lvl_embed = nn.Embedding(len(patch_nums), embed_dim)
        nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=math.sqrt(1 / embed_dim / 3))
        # Learned offset added to both caption streams before cross-attention.
        self.style_embed = nn.Parameter(torch.zeros(1, 1, caption_dim))
        nn.init.trunc_normal_(self.style_embed, mean=0, std=math.sqrt(1 / caption_dim / 3))
        attn_masking, attn_rope_pos, lvl_idx = self.build_attn_mask(patch_nums)
        # Derived purely from patch_nums, so kept out of the state dict.
        self.register_buffer("attn_masking", attn_masking, persistent=False)
        self.register_buffer("attn_rope_pos", attn_rope_pos, persistent=False)
        self.register_buffer("lvl_idx", lvl_idx, persistent=False)
        # (span, offset) pair handed to the audio cross-attention: audio keys are
        # spread over [patch_len, 2 * patch_len), the same RoPE range that the
        # current (second-half) tokens occupy in attn_rope_pos.
        self.audio_rope_pos = [max(patch_nums), max(patch_nums)]

    def forward(
        self,
        feat,
        audio_feat,
        prev_feat,
        video_caption_feat,
        audio_caption_feat,
        video_caption_feat_mask,
        audio_caption_feat_mask,
    ):
        """Run the block stack and return the refined current tokens.

        Args:
            feat: Current tokens, ``(batch, curr_seq_len, embed_dim)`` with
                ``curr_seq_len <= sum(patch_nums)``.
            audio_feat: Audio context for every block's audio cross-attention.
            prev_feat: Previous tokens; added to the full level embedding, so
                presumably ``(batch, sum(patch_nums), embed_dim)`` -- confirm.
            video_caption_feat / audio_caption_feat: Caption contexts (width
                ``caption_dim``); the style embedding is added to both.
            video_caption_feat_mask / audio_caption_feat_mask: Forwarded
                unchanged to the caption cross-attentions.

        Returns:
            The last ``curr_seq_len`` positions of the final block output.
        """
        curr_seq_len = feat.shape[1]
        lvl_embed = self.lvl_embed(self.lvl_idx)  # (1, L, embed_dim)
        attn_feat = torch.cat(
            [
                prev_feat + lvl_embed,
                feat + lvl_embed[:, :curr_seq_len],
            ],
            dim=1,
        )
        # The style offset is the same for every block: add it once, not per block.
        styled_video_caption = video_caption_feat + self.style_embed
        styled_audio_caption = audio_caption_feat + self.style_embed
        for block in self.attn_blocks:
            attn_feat = block(
                attn_feat,
                audio_feat,
                styled_video_caption,
                styled_audio_caption,
                video_caption_feat_mask,
                audio_caption_feat_mask,
                attn_rope_pos=self.attn_rope_pos,
                audio_rope_pos=self.audio_rope_pos,
                self_attn_bias=self.attn_masking,
            )
        return attn_feat[:, -curr_seq_len:, :]

    @torch.no_grad()
    def build_attn_mask(self, patch_nums, expand=True):
        """Build the block-causal attention mask, RoPE positions and level ids.

        Args:
            patch_nums: Token count of each level; ``L = sum(patch_nums)``.
            expand: If True, build the mask and positions for a
                ``[previous; current]`` sequence of ``2 * L`` tokens instead of a
                single ``L``-token sequence.

        Returns:
            ``(attn_masking, rope_pos, lvl_idx)`` where ``attn_masking`` is an
            additive float mask of shape ``(1, 1, N, N)`` (``N = 2L`` if
            ``expand`` else ``L``) holding ``0`` where attention is allowed and
            ``-inf`` elsewhere; ``rope_pos`` is an int64 tensor of ``N`` RoPE
            positions; ``lvl_idx`` is an int64 ``(1, L)`` tensor of level ids.
        """
        total = sum(patch_nums)
        levels = torch.cat([torch.full((pn,), i) for i, pn in enumerate(patch_nums)])
        d = levels.view(1, total, 1)  # level of each query (rows)
        dT = d.transpose(1, 2)  # level of each key (columns), shape (1, 1, L)
        lvl_idx = dT[:, 0].contiguous()
        # A token may attend to tokens of its own level and of every earlier level.
        attn_masking = torch.where(d >= dT, 0.0, -torch.inf).reshape(1, 1, total, total).contiguous()
        # Each level's tokens sit at the (rounded) centres of pn equal bins
        # covering positions [0, patch_len - 1].
        patch_len = max(patch_nums)
        level_positions = [
            torch.linspace(0, patch_len - 1, pn * 2 + 1)[1::2].round().long() for pn in patch_nums
        ]
        rope_pos = torch.cat(level_positions, dim=0)
        if expand:
            # Rows/cols [0, L) are the previous sequence and [L, 2L) the current one.
            # Previous tokens never see current tokens; current tokens see every
            # previous token plus the block-causal pattern among themselves.
            allow_all = torch.zeros_like(attn_masking)
            block_all = torch.full_like(attn_masking, -torch.inf)
            top_rows = torch.cat([attn_masking, block_all], dim=-1)
            bottom_rows = torch.cat([allow_all, attn_masking], dim=-1)
            attn_masking = torch.cat([top_rows, bottom_rows], dim=-2)
            # Current tokens reuse the level positions shifted past the previous ones.
            rope_pos = torch.cat([rope_pos, rope_pos + patch_len])
        return attn_masking, rope_pos, lvl_idx


class SelfCrossAttn(nn.Module):
    """A single transformer block: self-attention, three cross-attentions, MLP.

    Every sub-layer is applied residually as ``x = x + drop_path(sublayer(x))``.
    The attention sub-layers normalise their own inputs, and the MLP starts
    with a LayerNorm.

    Args:
        embed_dim: Token feature width.
        audio_dim: Width of the audio context.
        caption_dim: Width of the video/audio caption contexts.
        num_heads: Attention heads in every attention sub-layer.
        drop_path: Stochastic-depth rate; ``DropPath`` is used only when > 0.
    """

    def __init__(self, embed_dim, audio_dim, caption_dim, num_heads, drop_path=0.0):
        super().__init__()
        mlp_width = round(embed_dim * 4.0)
        # NOTE(review): DropPath is not among this file's visible imports;
        # confirm it is defined or imported elsewhere before using drop_path > 0.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.self_attn = SelfAttention(embed_dim=embed_dim, num_heads=num_heads)
        self.cross_attn = CrossAttention(
            embed_dim=embed_dim, context_embed_dim=audio_dim, num_heads=num_heads
        )
        self.vcaption_cross_attn = CrossAttention(
            embed_dim=embed_dim, context_embed_dim=caption_dim, num_heads=num_heads
        )
        self.acaption_cross_attn = CrossAttention(
            embed_dim=embed_dim, context_embed_dim=caption_dim, num_heads=num_heads
        )
        self.ffn = nn.Sequential(
            nn.LayerNorm(embed_dim, elementwise_affine=True, eps=1e-6),
            nn.Linear(embed_dim, mlp_width),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_width, embed_dim),
        )

    def forward(
        self,
        feat,
        audio_feat,
        video_caption_feat,
        audio_caption_feat,
        video_caption_feat_mask,
        audio_caption_feat_mask,
        attn_rope_pos,
        audio_rope_pos,
        self_attn_bias,
    ):
        """Apply the five residual sub-layers in order and return the result.

        Audio cross-attention receives ``audio_rope_pos`` for its keys. Both
        caption cross-attentions receive ``None`` key positions plus their
        respective masks.
        """
        residual = self.drop_path
        feat = feat + residual(self.self_attn(feat, attn_rope_pos, self_attn_bias))
        feat = feat + residual(self.cross_attn(feat, audio_feat, attn_rope_pos, audio_rope_pos))
        caption_branches = (
            (self.vcaption_cross_attn, video_caption_feat, video_caption_feat_mask),
            (self.acaption_cross_attn, audio_caption_feat, audio_caption_feat_mask),
        )
        for attend, caption, caption_mask in caption_branches:
            feat = feat + residual(attend(feat, caption, attn_rope_pos, None, caption_mask))
        return feat + residual(self.ffn(feat))


class SelfAttention(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12):
        super().__init__()
        self.num_heads = num_heads
        self.scale = int(embed_dim) ** (-0.5)
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.rearrange_qkv = Rearrange("b n (qkv h d) -> qkv b n h d", qkv=3, h=self.num_heads)
        self.rearrange_rope = Rearrange("b n h d -> b h n d")
        self.rearrange_out = Rearrange("b h n d -> b n (h d)")
        self.to_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.to_out = nn.Linear(embed_dim, embed_dim)
        # rotary positional embeddings
        self.rope = RotaryPositionalEmbeddings(dim=embed_dim // num_heads, max_seq_len=500)
        # layer norm
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=True, eps=1e-6)

    def forward(self, x, rope_pos, attn_bias):
        _, curr_len, _ = x.shape
        this_rope_pos = rope_pos[:curr_len]
        this_attn_bias = attn_bias[:, :, :curr_len, :curr_len]

        qkv = self.to_qkv(self.norm(x))
        q, k, v = self.rearrange_qkv(qkv).unbind(0)  # b n h d
        q = self.rearrange_rope(self.rope(q, input_pos=this_rope_pos))
        k = self.rearrange_rope(self.rope(k, input_pos=this_rope_pos))
        v = self.rearrange_rope(v)
        # compute attention
        out = torch.nn.functional.scaled_dot_product_attention(
            query=q, key=k, value=v, scale=self.scale, attn_mask=this_attn_bias
        )
        out = self.rearrange_out(out)
        out = self.to_out(out)
        return out


class CrossAttention(nn.Module):
    def __init__(self, embed_dim=768, context_embed_dim=768, num_heads=12):
        super().__init__()
        self.num_heads = num_heads
        self.scale = int(embed_dim) ** (-0.5)
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.rearrange_qkv = Rearrange("b n (h d) -> b n h d", h=self.num_heads)
        self.rearrange_rope = Rearrange("b n h d -> b h n d")
        self.rearrange_out = Rearrange("b h n d -> b n (h d)")
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(context_embed_dim, embed_dim)
        self.v_proj = nn.Linear(context_embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        # rotary positional embeddings
        self.rope = RotaryPositionalEmbeddings(dim=embed_dim // num_heads, max_seq_len=500)
        # layer norm
        self.self_norm = nn.LayerNorm(embed_dim, elementwise_affine=True, eps=1e-6)
        self.context_norm = nn.LayerNorm(context_embed_dim, elementwise_affine=True, eps=1e-6)

    def forward(self, x, context, rope_pos, context_rope_pos, attn_bias=None):
        _, curr_len, _ = x.shape
        this_rope_pos = rope_pos[:curr_len]
        x, context = self.self_norm(x), self.context_norm(context)
        if context_rope_pos is not None:
            patch_len, patch_offsets = context_rope_pos
            context_rope_pos = (
                torch.linspace(0, patch_len, steps=(context.shape[1] + 1))[:-1].long() + patch_offsets
            )
        q = self.rearrange_qkv(self.q_proj(x))
        k = self.rearrange_qkv(self.k_proj(context))
        v = self.rearrange_qkv(self.v_proj(context))
        q = self.rearrange_rope(self.rope(q, input_pos=this_rope_pos))
        k = self.rearrange_rope(self.rope(k, input_pos=context_rope_pos))
        v = self.rearrange_rope(v)

        # scaled_dot_product_attention inputs shape (B, Hq, L, c)
        out = torch.nn.functional.scaled_dot_product_attention(
            query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias
        )
        out = self.rearrange_out(out)
        out = self.out_proj(out)
        return out
